Grokking Modular Addition¶
Dani Balcells, April 2nd 2024
A few weeks ago, a bunch of us at Recurse Center finished going through most of the materials for the ARENA course, which Changlin Li kindly walked us through. Régis Schiavi and I decided to keep working on one of the optional tracks - replicating the results of the 2023 paper by Neel Nanda et al. titled “Progress Measures for Grokking Via Mechanistic Interpretability”. For several reasons, this bit of work felt like a step up in difficulty compared to the rest of the course, so it felt worth writing a blog post about.
This notebook borrows heavily from the ARENA course contents, specifically from the section on grokking and modular arithmetic.
Setup¶
# If necessary, install requirements from repository root
# !pip install -r ../requirements.txt
import torch as t
from torch import Tensor
import torch.nn.functional as F
import numpy as np
from pathlib import Path
import os
import sys
import plotly.express as px
import plotly.graph_objects as go
from functools import *
import gdown
from typing import List, Tuple, Union, Optional
from fancy_einsum import einsum
import einops
from jaxtyping import Float, Int
from tqdm import tqdm
from transformer_lens import utils, ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint
from transformer_lens.components import LayerNorm
from my_utils import *
if t.backends.mps.is_available():
    device = t.device('mps')
elif t.cuda.is_available():
    device = t.device('cuda')
else:
    device = t.device('cpu')
Motivation: what is "grokking"?¶
The main reason we conduct this study is to explore the concept of “grokking” in neural networks—the modular arithmetic task, which we’ll explain in detail in the next section, just happens to offer a convenient way to study grokking.
But what is grokking? Let’s take a step back and consider the concepts of generalization and overfitting, which are central to virtually all forms of supervised machine learning. In supervised learning, we have a set of labeled examples (for example, pictures of animals labeled “cat”, “dog”, “parrot”) and we train a model to correctly map inputs to labels. Critically, though, we want it to be able to perform this task not only on the data we use during the training process, but also on new data—we want it to be able to generalize. Unfortunately, if we’re not careful, our algorithm will simply memorize the labels for the training set instead of learning a general rule, leading it to perform well on the training set but not on new data—we call this overfitting. A common way to avoid overfitting is to hold back part of our labeled data in a test set, which isn’t used to update the model’s parameters during training, and is instead only used to measure the model’s performance on unseen data (i.e. its ability to generalize).
In many common and desirable training scenarios, performance on the test set roughly tracks performance on the training set. This usually indicates that the model is indeed learning a general rule, and performs only slightly worse on unseen data than on training data.
However, recent experiments have revealed an unusual relationship between training and test loss in some scenarios. In an initial phase, loss on the training set drops quickly while loss on the test set remains high, suggesting that the model has “memorized” the relationship between inputs and outputs instead of learning a general rule. Confronted with this behavior, we might assume the training isn’t working, and start over after making some changes to our model or training process. It turns out, though, that in some cases, including modular addition, given just the right amount of training data, the test loss suddenly drops after a considerable number of epochs (about 10,000 in our case) during which training loss remained low and test loss remained high: the model has learned a general rule long after memorizing the training set. We call this sudden generalization far past the point of memorization “grokking”.
This behavior is puzzling, and prompts a number of intriguing questions: Why does the model generalize so far past the point of memorization? If the training loss is already near zero, what pushes the model to keep learning until it reaches a general solution? How does the transition from a memorized solution to a general one take place at the level of individual parameters and circuits inside the model?
It also has interesting implications from the safety point of view, as it exemplifies emergent behavior in deep learning models: qualitative leaps in capability that appear suddenly over a relatively short number of training epochs. Emergent behavior and this sort of phase change during training are examples of the “more is different” principle in machine learning, which describes how continuous, quantitative progress can lead to sudden, qualitative differences in outcomes and capabilities. This is important from a safety perspective because it makes it hard for us to predict future capabilities by linear extrapolation of current capabilities. What if, for example, self-perception or the ability to deceive were also emergent properties that arise in phase shifts? A better understanding of grokking might help us shed light on these questions.
The task: grokking modular addition¶
The task we’re focusing on in this post is modular addition: $z = (x + y) \% p$, that is, the residue of dividing $x + y$ by $p$. For example, $(70 + 50) \% 113 = 120 \% 113 = 7$.
In our setup, the model learns a fairly constrained version of this problem: we fix p to the value 113, and limit x and y to the integers between 0 and 112 (i.e. between 0 and $p-1$).
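The task itself is a one-liner in Python (a quick illustrative sketch; the helper name `mod_add` is mine, not from the training code):

```python
p = 113

def mod_add(x: int, y: int) -> int:
    """Modular addition: the residue of dividing (x + y) by p."""
    return (x + y) % p

print(mod_add(70, 50))    # 7, since 120 % 113 = 7
print(mod_add(112, 112))  # 111, since 224 % 113 = 111
```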
The model we are reverse-engineering is a one-layer transformer with four attention heads, no layer norm, d_model = 128, and d_mlp = 512. The diagram below is taken from the ARENA materials.

The model takes sequences of three one-hot encoded vectors as inputs: the first two correspond to the one-hot encodings of the inputs x and y, and the third one corresponds to the one-hot encoding of =. Each one-hot encoded vector has length p+1, allowing us to represent the integers between 0 and p-1, plus the token =.
The model outputs three logit vectors, one for each element in the input sequence. The p values in each logit vector are unnormalized log-probabilities (logits) over the integers between 0 and p-1.
We keep only the output token predicted for input token = as the result of the modular addition operation, that is, we feed the three-element input sequence forward through the transformer, and take only the last (i.e. the third) token of the output sequence as the result of the operation. We train the model on the cross-entropy loss between the logits for this last output token and the one-hot encoding of the correct result of the modular addition.
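A minimal sketch of this final-token loss computation (tensor names and shapes here are illustrative, not the actual training code):

```python
import torch as t
import torch.nn.functional as F

p = 113
batch_size = 8

# Hypothetical model outputs for a batch of [x, y, '='] sequences:
# one length-p logit vector per position, shape (batch, seq_len=3, p)
logits = t.randn(batch_size, 3, p)
targets = t.randint(0, p, (batch_size,))  # correct values of (x + y) % p

# Score only the prediction made at the final ('=') position
final_logits = logits[:, -1, :]
loss = F.cross_entropy(final_logits, targets)
```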
Note that both one-hot encoded inputs and logit outputs represent distributions over the discrete and relatively limited space of the integers between 0 and p-1. This means that we’re effectively treating the modular addition task as a classification task, from discrete and limited inputs to discrete and limited outputs. The general case of learning modular addition for two continuous and arbitrary inputs is far more complex and beyond the scope of our study, which is primarily concerned with grokking rather than modular addition.
Training is done by generating the full set of [x, y, '='] sequences for all values of x and y between 0 and p-1. Note that, for this toy model, the entire input space is fairly small—there are only $p \cdot p = 12,769$ possible inputs. We train on full batches, using 30% of the total sequences for our training set. We use AdamW with very high weight decay (wd = 1). Weight decay plays an important role in grokking, as we will later see.
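A sketch of how that dataset can be generated and split (the 30% train fraction comes from the text above; the seed and split mechanics here are illustrative, not the original training code):

```python
import torch as t

p = 113
eq_token = p  # index of the '=' token in the vocabulary

# Full input space: every [x, y, '='] sequence
all_data = t.tensor([(x, y, eq_token) for x in range(p) for y in range(p)])
all_labels = (all_data[:, 0] + all_data[:, 1]) % p  # ground-truth results

# Random 30% train split
frac_train = 0.3
g = t.Generator().manual_seed(0)
perm = t.randperm(p * p, generator=g)
n_train = int(frac_train * p * p)
train_idx, test_idx = perm[:n_train], perm[n_train:]
train_data, train_labels = all_data[train_idx], all_labels[train_idx]
```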
Loading the model¶
The code below defines a HookedTransformer for our task as described above, and loads in the data from the entire training process.
p = 113
cfg = HookedTransformerConfig(
    n_layers = 1,
    d_vocab = p+1,
    d_vocab_out = p,  # outputs are logits over the p possible results
    d_model = 128,
    d_mlp = 4 * 128,
    n_heads = 4,
    d_head = 128 // 4,
    n_ctx = 3,
    act_fn = "relu",
    normalization_type = None,
    device = device
)
model = HookedTransformer(cfg)
large_root = Path('large_files')
if not large_root.exists():
    os.mkdir(large_root)
from huggingface_hub import hf_hub_download
REPO_ID = "callummcdougall/grokking_full_run_data"
FILENAME = "full_run_data.pth"
hf_hub_download(
    repo_id = REPO_ID,
    filename = FILENAME,
    local_dir = large_root,
)
full_run_data_path = large_root / FILENAME
full_run_data = t.load(full_run_data_path, map_location=device)
state_dict = full_run_data["state_dicts"][400]
model = load_in_state_dict(model, state_dict)
Visualizing the grokking curve¶
We can plot the loss over the train and test datasets to view the characteristic "grokking" curve.
lines(
    lines_list=[
        full_run_data['train_losses'][::10],
        full_run_data['test_losses']
    ],
    labels=['train loss', 'test loss'],
    title='Grokking Training Curve',
    x=np.arange(5000)*10,
    xaxis='Epoch',
    yaxis='Loss',
    log_y=True
)
Looking around: everything is periodic!¶
We start by running all possible inputs through the model, and capturing all internal variables in a cache.
# Each input sequence is [x, y, '=']; index p is the '=' token
all_data = t.tensor([(i, j, p) for i in range(p) for j in range(p)]).to(device)
# fn computes the ground-truth label (i + j) % p
labels = t.tensor([fn(i, j) for i, j, _ in all_data]).to(device)
original_logits, cache = model.run_with_cache(all_data)
Using the transformer_lens library, we can access all activations and intermediate variables in the transformer for any input in our batch.
We start by having a look at the attention scores. Since, as described earlier, we only keep the logits for the final token in the output sequence, we index into the attention scores to consider only the attention paid by the final token to all tokens in the sequence (including itself), for each attention head, and averaged across all input sequences.
As the plot below shows, for all four attention heads, the final token = pays no attention to itself, and pays equal attention to the first two input tokens x and y. This makes intuitive sense: the third and final token = holds no useful information before the attention layer, whose job is precisely to move information from the other input tokens to the final position. Since addition is commutative, it makes sense that both x and y inputs get weighted equally by the attention layer.
attn_p = cache['blocks.0.attn.hook_pattern'][:,:,-1,:]
print(f'{attn_p.shape=}')
mean_final_pos_attn = attn_p.mean(0)
print(f'{mean_final_pos_attn.shape=}')
imshow(mean_final_pos_attn, xaxis='Position in input sequence', yaxis='Attention head no.',
       title='Attention paid by final token to each token (average across all input)')
attn_p.shape=torch.Size([12769, 4, 3])
mean_final_pos_attn.shape=torch.Size([4, 3])
Now, for the conceptual leap that I found the most interesting (and hardest to grok myself) in this whole exercise. Remember that we’ve fed the entire input space through the model, a total of 12,769 (or $p^2$ for $p=113$) input sequences [x, y, '='] for all possible values of x and y between 0 and 112 (or $p-1$). What if we took our batch dimension, of size 12,769, and rearranged it into two separate dimensions, each of size 113? This would allow us to view, for example, the value of a certain neuron activation not as a single vector $a_i$ of 12,769 values, but as a 113x113 matrix $a_{xy}$ representing the same values as a function of the inputs x and y.
Why bother doing this? You’ll see that things get very interesting when we rearrange activations and attention scores this way.
For example, as you can see below, the activations of the first three neurons in the MLP all follow very periodic patterns as functions of the inputs x and y.
# Take MLP activations from cache
neuron_acts_post = cache['blocks.0.mlp.hook_post'][:, -1, :]
neuron_acts_pre = cache['blocks.0.mlp.hook_pre'][:, -1, :]
# Rearrange batch dimension (p^2) into two separate dimensions p, p
neuron_acts_post_sq = einops.rearrange(neuron_acts_post, "(x y) d_mlp -> x y d_mlp", x=p)
neuron_acts_pre_sq = einops.rearrange(neuron_acts_pre, "(x y) d_mlp -> x y d_mlp", x=p)
top_k = 3
inputs_heatmap(
    neuron_acts_post_sq[..., :top_k],
    title=f'Activations for first {top_k} neurons',
    animation_frame=2,
    animation_name='Neuron'
)
The same is true for attention scores, as you can see below. Note that we’re plotting only the attention paid by the final token to the token at position 0. As we saw earlier, the attention scores for positions 0 and 1 are almost identical (they are both close to 0.5, while the attention score for position 2 is 0). We can therefore expect (and verify) that the attention paid by the final token to the token at position 1 is a very similar periodic function of the inputs x and y.
attn_mat = cache['blocks.0.attn.hook_pattern'][:, :, 2, :]
attn_mat_sq = einops.rearrange(attn_mat, "(x y) head seq -> x y head seq", x=p)
inputs_heatmap(
    attn_mat_sq[..., 0],
    title='Attention score for heads at position 0',
    animation_frame=2,
    animation_name='head'
)
We can also follow linear paths within the transformer to see how each input token contributes to the input of a given neuron. Looking at the OV circuit, we take the product $W_E \cdot W_V \cdot W_O \cdot W_{in}$ (i.e. the product of the embedding, attention head value, attention head output, and MLP input matrices), which gives us an effective tensor $W_{neur}$ of shape (n_heads, p, d_mlp). We can think of this as a (p, d_mlp) matrix for each attention head, whose i-th column represents the contribution to the i-th neuron’s input as a function of the input token. Note that we’re not counting attention weights here. Again, we see that the model exhibits very periodic behavior: attention weighting aside, as the value of the model’s inputs increases, the inputs to the MLP cycle through a certain periodic function.
# Get weight matrices from the model
W_O = model.W_O[0] # 0-th element here refers to the W_O matrix of the 0-th layer
W_V = model.W_V[0]
W_E = model.W_E[:-1]
W_in = model.W_in[0]
# Calculate effective matrix W_neur
W_neur = W_E @ W_V @ W_O @ W_in
print(f'{W_neur.shape=}')
top_k = 5
animate_multi_lines(
    W_neur[..., :top_k],
    y_index = [f'head {hi}' for hi in range(4)],
    labels = {'x': 'Input token', 'value': 'Contribution to neuron'},
    snapshot='Neuron',
    title=f'Contribution to first {top_k} neurons via OV-circuit of heads (not weighted by attention)'
)
W_neur.shape=torch.Size([4, 113, 512])
Similarly, we can turn to the QK circuit and see that it also exhibits periodic patterns. First, we define the effective matrix $W_{attn}$ as the product $r_{final\_pos}\cdot W_Q \cdot W_K^T \cdot W_E^T$, where $r_{final\_pos} = W_E(=) + W_{pos}(2)$, i.e. the initial value of the residual stream at the final position: the embedding for the = token plus the positional embedding for position 2. This effective matrix, indexed at a given input token a, gives us the (pre-softmax) attention score paid to token a by the = token in position 2. We’re essentially taking the bilinear form $v \cdot W_Q \cdot W_K^T \cdot w^T$, which gives us the attention score between the embedding vectors $v$ as query and $w$ as key, and fixing the query to the = token in position 2.
We again see that the attention scores are periodic functions of the input token.
# Get weight matrices from model
W_K = model.W_K[0]
W_Q = model.W_Q[0]
W_pos = model.W_pos
# Get query-side vector by summing the embedding for the '=' token plus the positional embedding for position 2
final_pos_resid_initial = model.W_E[-1] + W_pos[2]
# Calculate effective matrix W_attn
W_QK = W_Q @ W_K.mT
W_attn = final_pos_resid_initial @ W_QK @ W_E.T / np.sqrt(cfg.d_head)
print(f'{W_attn.shape=}')
lines(
    W_attn,
    labels = [f'head {hi}' for hi in range(4)],
    xaxis='Input token',
    yaxis='Contribution to attn score',
    title='Contribution to attention score (pre-softmax) for each head'
)
W_attn.shape=torch.Size([4, 113])
Finally, as an aside that wasn’t included in this part of the exercise in the ARENA coursework, we can have a look at what the embedding matrix is doing. As a function of the value of the input token, we can easily see that the 0th and 1st (and all other) dimensions of the embedding space are highly periodic.
embedding_dims = [0, 1]
lines(
W_E[:, embedding_dims].T,
labels=[f'Embedding dim #{i}' for i in embedding_dims],
xaxis='Input token',
yaxis='Embedding dimension value',
title='Embedding dimension values as a function of the input token'
)
Analyzing in the Fourier basis¶
Thankfully, we have a tool in our belt that makes it easy to deal with periodic functions: the Fourier transform! Essentially, it is a basis change. It takes a sequence of length N and represents it as N Fourier coefficients (hence it being a proper basis change). The Fourier coefficients represent the amplitudes of a series of sine and cosine terms, with frequencies at integer multiples of $\frac{2\pi}{N}$, that can fully reconstruct the original sequence. Intuitively, this means that it can tell us, in a conveniently lossless way, what frequencies make up a given sequence.
The code below generates the 1D Fourier basis for a given sequence length p. As you can see, the resulting (p, p) matrix can be thought of as p sine and cosine terms of length p, with frequencies between 0 (the constant term) and $(p-1)/2$ cycles (the maximum, since $p$ is odd).
def make_fourier_basis(p: int) -> Tuple[Tensor, List[str]]:
    '''
    Returns a pair `fourier_basis, fourier_basis_names`, where `fourier_basis` is
    a `(p, p)` tensor whose rows are Fourier components and `fourier_basis_names`
    is a list of length `p` containing the names of the Fourier components (e.g.
    `["const", "cos 1", "sin 1", ...]`). You may assume that `p` is odd.
    '''
    fourier_basis = t.ones(p, p)
    fourier_basis_names = ['Const']
    for i in range(1, p // 2 + 1):
        # Define each of the cos and sin terms
        fourier_basis[2*i-1] = t.cos(2*t.pi*t.arange(p)*i/p)
        fourier_basis[2*i] = t.sin(2*t.pi*t.arange(p)*i/p)
        fourier_basis_names.extend([f'cos {i}', f'sin {i}'])
    # Normalize vectors, and return them
    fourier_basis /= fourier_basis.norm(dim=1, keepdim=True)
    return fourier_basis, fourier_basis_names
fourier_basis, fourier_basis_names = make_fourier_basis(p)
animate_lines(
    fourier_basis,
    snapshot_index=fourier_basis_names,
    snapshot='Fourier Component',
    title='Graphs of Fourier Components (Use Slider)'
)
The Fourier transform of a sequence of length p can be computed as the dot product of the sequence with each of the basis terms. (Strictly speaking, this dense matrix product is the discrete Fourier transform; the FFT is a faster algorithm for computing the same result, but we keep the fft naming to match the code.)
def fft1d(x: t.Tensor) -> t.Tensor:
    '''
    Returns the 1D Fourier transform of `x`,
    which can be a vector or a batch of vectors.
    x.shape = (..., p)
    '''
    return x @ fourier_basis.T
For example, we can compute the FFT of the first dimension of the embeddings as a function of the input token. As you can see, what is a very dense and periodic function in the input space becomes very sparse in the Fourier domain. This highlights that there are only a handful of frequencies that explain the periodic behaviors we see in the model.
lines(
    W_E[:, [0]].T,
    labels=['Embedding dim 0'],
    xaxis='Input token',
    yaxis='Embedding dimension value',
    title='Embedding dimension 0 as a function of the input token'
)
lines(
    fft1d(W_E[:, [0]].T.to('cpu')).pow(2),
    labels=['Embedding dim 0'],
    xaxis='Input token',
    yaxis='FFT value squared',
    title='FFT of embedding dimension 0 as a function of the input token'
)
Most of the periodic patterns we saw earlier, however, are 2D—how can we study 2D periodic patterns? The answer, perhaps predictably, is the 2D Fourier transform. We can understand it by analogy to the 1D Fourier transform:
- For every pair $(i, j)$ of horizontal and vertical frequencies, we generate a basis term as the outer product of the 1D Fourier basis terms for $i$ and $j$.
- We compute the 2D FFT at position $(i, j)$ as the dot product between our input and the basis term.
def fourier_2d_basis_term(i: int, j: int) -> Float[Tensor, "p p"]:
    '''
    Returns the 2D Fourier basis term corresponding to the outer product of the
    `i`-th component of the 1D Fourier basis in the `x` direction and the `j`-th
    component of the 1D Fourier basis in the `y` direction.
    Returns a 2D tensor of shape `(p, p)`.
    '''
    return (fourier_basis[i][:, None] * fourier_basis[j][None, :])
def fft2d(tensor: t.Tensor) -> t.Tensor:
    '''
    Returns the components of `tensor` in the 2D Fourier basis.
    Assumes that the input has shape `(p, p, ...)`, where the
    last dimensions (if present) are the batch dims.
    Output has the same shape as the input.
    '''
    # fourier_basis[i] is the i-th basis vector, which we contract along each axis
    return einops.einsum(
        tensor, fourier_basis, fourier_basis,
        "px py ..., i px, j py -> i j ..."
    )
Let's test this by generating a sample 2D input by combining a handful of frequencies, and taking its 2D FFT:
example_fn = sum([
    fourier_2d_basis_term(4, 6),
    fourier_2d_basis_term(14, 46) / 3,
    fourier_2d_basis_term(97, 100) / 6
])
inputs_heatmap(example_fn.T, title="Example periodic function")
imshow_fourier(
    fft2d(example_fn),
    title='Example periodic function in 2D Fourier basis'
)
As you can see by zooming in on those tiny dots on your screen, the dense and periodic 2D pattern is incredibly sparse in the Fourier basis.
Let's turn now to the periodic patterns we observed earlier in our model, and see what they look like in the Fourier basis.
We'll start with the attention score at position 0 for all heads:
inputs_heatmap(
    attn_mat_sq[..., 0],
    title='Attention score for heads at position 0',
    animation_frame=2,
    animation_name='head'
)
attn_mat_fourier_basis = fft2d(attn_mat_sq.to('cpu'))
# Plot results
imshow_fourier(
    attn_mat_fourier_basis[..., 0],
    title='Attention score for heads at position 0, in Fourier basis',
    animation_frame=2,
    animation_name='head'
)
Next, we'll look at the neuron activations:
top_k = 3
inputs_heatmap(
    neuron_acts_post_sq[..., :top_k],
    title=f'Activations for first {top_k} neurons',
    animation_frame=2,
    animation_name='Neuron'
)
neuron_acts_post_fourier_basis = fft2d(neuron_acts_post_sq.to('cpu'))
top_k = 10
imshow_fourier(
    neuron_acts_post_fourier_basis[..., :top_k],
    title=f'Activations for first {top_k} neurons in Fourier basis',
    animation_frame=2,
    animation_name='Neuron'
)
Next, we can look at the embeddings—here, we're taking the FFT of each embedding dimension as a function of the input token, and summing across embedding dimensions. Note that we take element-wise squares before summing, since we mainly care about the magnitude of the FFT, and not its sign.
This plot shows us that all embedding dimensions are periodic functions of the input token at a handful of specific frequencies.
line(
    (fourier_basis @ W_E.to('cpu')).pow(2).sum(1),
    hover=fourier_basis_names,
    title='Norm of embedding of each Fourier Component',
    xaxis='Fourier Component',
    yaxis='Norm'
)